Explanation methods¶

Deep learning models are becoming better and better at making predictions. As researchers, regulators, and users, we are also interested in asking additional questions; namely, we would like to explain a decision in terms of the input. Where in an image is the model focusing? What cues is the prediction based on? Does it match our expectations? Can the model be trusted?

In this practical, we will explore popular methods for explaining decisions made by image classifiers:

  • Simple occlusion
  • Gradient norm
  • Gradient x input
  • GradCAM
  • Integrated gradients

With a working implementation of each method, we will compare explanations qualitatively on a few sample images.

Furthermore, we will evaluate the correctness of each method quantitatively using the deletion score.
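In one common formulation, the deletion score progressively removes the pixels an explanation ranks as most important and tracks how the model's confidence degrades; a faithful explanation should make confidence drop quickly. A rough sketch of the idea, with a hypothetical `predict_prob` callable standing in for the model:

```python
import numpy as np

def deletion_score(predict_prob, image, saliency, n_steps=20):
    """Sketch of a deletion curve: lower area under it = better explanation.

    predict_prob: callable mapping an [H, W, C] image to the
                  predicted-class probability (assumed given).
    saliency:     [H, W] importance map produced by an explanation method.
    """
    H, W, _ = image.shape
    order = np.argsort(saliency.ravel())[::-1]  # most important pixels first
    probs = [predict_prob(image)]
    img = image.copy()
    chunk = max(1, len(order) // n_steps)
    for start in range(0, len(order), chunk):
        ys, xs = np.unravel_index(order[start:start + chunk], (H, W))
        img[ys, xs, :] = 0.0  # "delete" pixels by zeroing them out
        probs.append(predict_prob(img))
    # Approximate the area under the confidence curve with the mean
    return float(np.mean(probs))
```

This is only an illustration of the metric's shape; the exact step size and baseline value (zero, blur, noise) are design choices.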

Setup¶

In [1]:
!pip install "jax[cuda]" -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'

!pip install \
  flax optax \
  'git+https://github.com/n2cholas/jax-resnet.git' \
  tensorflow-datasets \
  better_exceptions
Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Collecting jax[cuda]
  Downloading jax-0.6.2-py3-none-any.whl.metadata (13 kB)
INFO: pip is looking at multiple versions of jax[cuda] to determine which version is compatible with other requirements. This could take a while.
  ...
  Using cached jax-0.2.22.tar.gz (776 kB)
  Preparing metadata (setup.py) ... done
WARNING: jax 0.2.22 does not provide the extra 'cuda'
Building wheels for collected packages: jax
  DEPRECATION: Building 'jax' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'jax'. Discussion can be found at https://github.com/pypa/pip/issues/6334
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.22-py3-none-any.whl size=890324 sha256=d8f8654332391b7d5273b9106ae0648bedbd3b86ff6a2ad821d0041fb259fafa
  Stored in directory: /Users/silpasoninallacheruvu/Library/Caches/pip/wheels/07/6c/f6/11dc726435faa88188b1f08d34780c161bb9eb966f3a5a01a7
Successfully built jax
Installing collected packages: jax
Successfully installed jax-0.2.22
Collecting git+https://github.com/n2cholas/jax-resnet.git
  Cloning https://github.com/n2cholas/jax-resnet.git to /private/var/folders/mj/bwns80rs3psccqv0gz4hh8zc0000gn/T/pip-req-build-wrhzyzg8
  Running command git clone --filter=blob:none --quiet https://github.com/n2cholas/jax-resnet.git /private/var/folders/mj/bwns80rs3psccqv0gz4hh8zc0000gn/T/pip-req-build-wrhzyzg8
  Resolved https://github.com/n2cholas/jax-resnet.git to commit 5b00735aa0a68ec239af4a728ad4a596c1b551f6
  Preparing metadata (setup.py) ... done
Collecting flax
  Downloading flax-0.10.7-py3-none-any.whl.metadata (11 kB)
Collecting optax
  Downloading optax-0.2.6-py3-none-any.whl.metadata (7.6 kB)
Requirement already satisfied: tensorflow-datasets in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (4.9.9)
Collecting better_exceptions
  Using cached better_exceptions-0.3.3-py3-none-any.whl.metadata (466 bytes)
Requirement already satisfied: jax in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax-resnet==0.0.4) (0.2.22)
Collecting jaxlib (from jax-resnet==0.0.4)
  Using cached jaxlib-0.6.2-cp310-cp310-macosx_11_0_arm64.whl.metadata (1.3 kB)
Collecting jax (from jax-resnet==0.0.4)
  Using cached jax-0.6.2-py3-none-any.whl.metadata (13 kB)
Collecting msgpack (from flax)
  Downloading msgpack-1.1.1-cp310-cp310-macosx_11_0_arm64.whl.metadata (8.4 kB)
Collecting orbax-checkpoint (from flax)
  Downloading orbax_checkpoint-0.11.25-py3-none-any.whl.metadata (2.3 kB)
Collecting tensorstore (from flax)
  Downloading tensorstore-0.1.77-cp310-cp310-macosx_11_0_arm64.whl.metadata (21 kB)
Requirement already satisfied: rich>=11.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (14.1.0)
Requirement already satisfied: typing_extensions>=4.2 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (4.15.0)
Requirement already satisfied: PyYAML>=5.4.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from flax) (6.0.3)
Collecting treescope>=0.1.7 (from flax)
  Downloading treescope-0.1.10-py3-none-any.whl.metadata (6.6 kB)
Requirement already satisfied: absl-py>=0.7.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from optax) (2.3.1)
Collecting chex>=0.1.87 (from optax)
  Using cached chex-0.1.90-py3-none-any.whl.metadata (18 kB)
Requirement already satisfied: numpy>=1.18.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from optax) (1.26.4)
Requirement already satisfied: dm-tree in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (0.1.9)
Requirement already satisfied: etils>=1.6.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (1.13.0)
Requirement already satisfied: immutabledict in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (4.2.1)
Requirement already satisfied: promise in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (2.3)
Requirement already satisfied: protobuf>=3.20 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (4.21.12)
Requirement already satisfied: psutil in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (7.1.0)
Requirement already satisfied: pyarrow in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (21.0.0)
Requirement already satisfied: requests>=2.19.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (2.32.5)
Requirement already satisfied: simple_parsing in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (0.1.7)
Requirement already satisfied: tensorflow-metadata in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (1.17.2)
Requirement already satisfied: termcolor in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (3.1.0)
Requirement already satisfied: toml in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (0.10.2)
Requirement already satisfied: tqdm in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (4.67.1)
Requirement already satisfied: wrapt in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from tensorflow-datasets) (1.17.3)
Collecting toolz>=0.9.0 (from chex>=0.1.87->optax)
  Using cached toolz-1.0.0-py3-none-any.whl.metadata (5.1 kB)
Requirement already satisfied: einops in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (0.8.1)
Requirement already satisfied: fsspec in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (2025.9.0)
Requirement already satisfied: importlib_resources in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (6.5.2)
Requirement already satisfied: zipp in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from etils[edc,enp,epath,epy,etree]>=1.6.0; python_version < "3.11"->tensorflow-datasets) (3.23.0)
Collecting ml_dtypes>=0.5.0 (from jax->jax-resnet==0.0.4)
  Using cached ml_dtypes-0.5.3-cp310-cp310-macosx_10_9_universal2.whl.metadata (8.9 kB)
Requirement already satisfied: opt_einsum in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax->jax-resnet==0.0.4) (3.4.0)
Requirement already satisfied: scipy>=1.12 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from jax->jax-resnet==0.0.4) (1.15.3)
Requirement already satisfied: charset_normalizer<4,>=2 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (3.4.3)
Requirement already satisfied: idna<4,>=2.5 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (3.10)
Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from requests>=2.19.0->tensorflow-datasets) (2025.8.3)
Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from rich>=11.1->flax) (4.0.0)
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from rich>=11.1->flax) (2.19.2)
Requirement already satisfied: mdurl~=0.1 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax) (0.1.2)
Requirement already satisfied: attrs>=18.2.0 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from dm-tree->tensorflow-datasets) (25.3.0)
Requirement already satisfied: nest_asyncio in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.6.0)
Collecting aiofiles (from orbax-checkpoint->flax)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting humanize (from orbax-checkpoint->flax)
  Using cached humanize-4.13.0-py3-none-any.whl.metadata (7.8 kB)
Collecting simplejson>=3.16.0 (from orbax-checkpoint->flax)
  Downloading simplejson-3.20.2-cp310-cp310-macosx_11_0_arm64.whl.metadata (3.4 kB)
Requirement already satisfied: six in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from promise->tensorflow-datasets) (1.17.0)
Requirement already satisfied: docstring-parser<1.0,>=0.15 in /opt/homebrew/Caskroom/miniconda/base/envs/dd2412/lib/python3.10/site-packages (from simple_parsing->tensorflow-datasets) (0.17.0)
Downloading flax-0.10.7-py3-none-any.whl (456 kB)
Downloading optax-0.2.6-py3-none-any.whl (367 kB)
Using cached better_exceptions-0.3.3-py3-none-any.whl (11 kB)
Using cached chex-0.1.90-py3-none-any.whl (101 kB)
Downloading jax-0.6.2-py3-none-any.whl (2.7 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.7/2.7 MB 11.0 MB/s  0:00:00
Downloading jaxlib-0.6.2-cp310-cp310-macosx_11_0_arm64.whl (54.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.3/54.3 MB 11.7 MB/s  0:00:04
Downloading ml_dtypes-0.5.3-cp310-cp310-macosx_10_9_universal2.whl (667 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 667.4/667.4 kB 9.6 MB/s  0:00:00
Using cached toolz-1.0.0-py3-none-any.whl (56 kB)
Downloading treescope-0.1.10-py3-none-any.whl (182 kB)
Downloading msgpack-1.1.1-cp310-cp310-macosx_11_0_arm64.whl (78 kB)
Downloading orbax_checkpoint-0.11.25-py3-none-any.whl (563 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 563.1/563.1 kB 3.2 MB/s  0:00:00
Downloading simplejson-3.20.2-cp310-cp310-macosx_11_0_arm64.whl (76 kB)
Downloading tensorstore-0.1.77-cp310-cp310-macosx_11_0_arm64.whl (13.8 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 8.2 MB/s  0:00:01
Downloading aiofiles-24.1.0-py3-none-any.whl (15 kB)
Using cached humanize-4.13.0-py3-none-any.whl (128 kB)
Building wheels for collected packages: jax-resnet
  DEPRECATION: Building 'jax-resnet' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'jax-resnet'. Discussion can be found at https://github.com/pypa/pip/issues/6334
  Building wheel for jax-resnet (setup.py) ... done
  Created wheel for jax-resnet: filename=jax_resnet-0.0.4-py2.py3-none-any.whl size=11972 sha256=c05f1546fe2444af8d534b2fceb6ffe850111352df782eb3a938788e235cbc85
  Stored in directory: /private/var/folders/mj/bwns80rs3psccqv0gz4hh8zc0000gn/T/pip-ephem-wheel-cache-xxf9xpr2/wheels/2b/57/8c/a9e9b5ae55d9dfc4466c910140d5625f44eb779908cc868b2d
Successfully built jax-resnet
Installing collected packages: better_exceptions, treescope, toolz, simplejson, msgpack, ml_dtypes, humanize, aiofiles, tensorstore, jaxlib, jax, orbax-checkpoint, chex, optax, flax, jax-resnet
  Attempting uninstall: ml_dtypes
    Found existing installation: ml-dtypes 0.3.2
    Uninstalling ml-dtypes-0.3.2:
      Successfully uninstalled ml-dtypes-0.3.2
  Attempting uninstall: jax
    Found existing installation: jax 0.2.22
    Uninstalling jax-0.2.22:
      Successfully uninstalled jax-0.2.22
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.16.2 requires ml-dtypes~=0.3.1, but you have ml-dtypes 0.5.3 which is incompatible.
Successfully installed aiofiles-24.1.0 better_exceptions-0.3.3 chex-0.1.90 flax-0.10.7 humanize-4.13.0 jax-0.6.2 jax-resnet-0.0.4 jaxlib-0.6.2 ml_dtypes-0.5.3 msgpack-1.1.1 optax-0.2.6 orbax-checkpoint-0.11.25 simplejson-3.20.2 tensorstore-0.1.77 toolz-1.0.0 treescope-0.1.10
In [2]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

import tensorflow as tf

tf.get_logger().setLevel("WARNING")
tf.config.experimental.set_visible_devices([], "GPU")

from collections import defaultdict
from functools import partial
from typing import Sequence

import flax.core
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax_resnet
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import sklearn.metrics
import tabulate
import tensorflow_datasets as tfds
import torch
import tqdm
from flax.training.train_state import TrainState
from IPython.display import display
from jax import jit, vmap

RED = np.array([1.0, 0, 0])
BLUE = np.array([0, 0, 1.0])


@jax.jit
def normalize_zero_one(x):
    """Normalize a vector between 0 and 1."""
    res = (x - x.min()) / (x.max() - x.min())
    res = jnp.clip(res, a_min=0, a_max=1)
    return res


@jax.jit
def normalize_max(x):
    """Normalize a vector between -1 and 1."""
    res = x / jnp.abs(x).max()
    res = jnp.clip(res, a_min=-1, a_max=1)
    return res


@jax.jit
def blend(a, b, alpha: float):
    """Blend two float-valued images"""
    return (1 - alpha) * a + alpha * b

Dataset¶

For simplicity, we will use the small Imagenette dataset, which contains 10 easy-to-classify categories from ImageNet.

Here we load the dataset and show a few images that will be used throughout this notebook.

In [33]:
CLASS_NAMES = [
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]


def show_images(images, labels=None, logits=None, ncols=4, width_one_img_inch=3.0):
    B, H, W, *_ = images.shape
    nrows = int(np.ceil(B / ncols))
    fig, axs = plt.subplots(
        nrows,
        ncols,
        figsize=width_one_img_inch * np.array([1, H / W]) * np.array([ncols, nrows]),
        sharex=True,
        sharey=True,
        squeeze=False,
        facecolor="white",
    )
    for b in range(B):
        ax = axs.flat[b]
        ax.imshow(images[b])
        if labels is not None:
            ax.set_title(CLASS_NAMES[labels[b]])
        if logits is not None:
            pred = logits[b].argmax()
            prob = nn.softmax(logits[b])[pred]
            color = (
                "blue" if labels is None else ("green" if labels[b] == pred else "red")
            )
            p = mpl.patches.Patch(color=color, label=f"{prob:.2%} {CLASS_NAMES[pred]}")
            ax.legend(handles=[p])
    fig.tight_layout()
    display(fig)
    plt.close(fig)


def resize(image, label):
    image = tf.image.resize_with_pad(image, 224, 224)
    return image / 255.0, label


ds_builder = tfds.builder("imagenette/320px-v2", data_dir=".")
ds_builder.download_and_prepare()

ds_info = ds_builder.info
print(f"Dataset info: {ds_info}")

ds = ds_builder.as_dataset(split="train", batch_size=None, as_supervised=True)
ds = ds.map(resize)
ds = ds.batch(8)
ds = tfds.as_numpy(ds)
viz_batch = next(iter(ds))

images, labels = viz_batch
show_images(images, labels)
Dataset info: tfds.core.DatasetInfo(
    name='imagenette',
    full_name='imagenette/320px-v2/1.0.0',
    description="""
    Imagenette is a subset of 10 easily classified classes from the Imagenet
    dataset. It was originally prepared by Jeremy Howard of FastAI. The objective
    behind putting together a small version of the Imagenet dataset was mainly
    because running new ideas/algorithms/experiments on the whole Imagenet take a
    lot of time.
    
    This version of the dataset allows researchers/practitioners to quickly try out
    ideas and share with others. The dataset comes in three variants:
    
    *   Full size
    *   320 px
    *   160 px
    
    Note: The v2 config correspond to the new 70/30 train/valid split (released in
    Dec 6 2019).
    """,
    config_description="""
    320px variant.
    """,
    homepage='https://github.com/fastai/imagenette',
    data_dir='imagenette/320px-v2/1.0.0',
    file_format=tfrecord,
    download_size=325.84 MiB,
    dataset_size=332.71 MiB,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=10),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    nondeterministic_order=False,
    splits={
        'train': <SplitInfo num_examples=9469, num_shards=2>,
        'validation': <SplitInfo num_examples=3925, num_shards=1>,
    },
    citation="""@misc{imagenette,
      author    = "Jeremy Howard",
      title     = "imagenette",
      url       = "https://github.com/fastai/imagenette/"
    }""",
)
[figure: grid of sample Imagenette images with their class labels]
In [4]:
def load_resnet(size):
    """Load a resnet model and return resnet_logits_fn and its variables.

    Returns:
        logits_fn: a jitted function that given one image applies
                   the resnet model and returns the max logit
                   value and the logits vector
        variables: resnet variables to use with logits_fn
    """

    def logits_fn(variables, img):
        # img: [H, W, C], float32 in range [0, 1]
        assert img.ndim == 3
        img = normalize_for_resnet(img)
        logits = model.apply(variables, img[None, ...])[0]
        logits = imagenet_to_imagenette_logits(logits)
        return logits.max(), logits

    ResNet, variables = jax_resnet.pretrained_resnet(size)
    model = ResNet()
    logits_fn = jax.jit(logits_fn)
    return logits_fn, variables


def normalize_for_resnet(image):
    mean = jnp.array([0.485, 0.456, 0.406])
    std = jnp.array([0.229, 0.224, 0.225])
    return (image - mean) / std


def imagenet_to_imagenette_logits(logits):
    """Select the 10 imagenette classes from the 1000 imagenet classes."""
    return logits[..., [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]]


logits_fn, variables = load_resnet(size=18)
images, labels = viz_batch
print(f"images shape: {images.shape}")
_, logits = jax.vmap(logits_fn, (None, 0))(variables, images)
print(f"logits shape: {logits.shape}")
print(f"expected logits shape: {(images.shape[0], len(CLASS_NAMES))}")
assert logits.shape == (images.shape[0], len(CLASS_NAMES))
show_images(images, labels, logits)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
images shape: (20, 224, 224, 3)
logits shape: (20, 10)
expected logits shape: (20, 10)
[figure: sample images with predicted class, confidence, and green/red correctness coloring]

Pretrained ResNet¶

We will focus on a ResNet-18 model for the explanations, which has been ported from PyTorch thanks to this repo.

The simplest way to load and run a ResNet model using jax_resnet is:

ResNet, variables = jax_resnet.pretrained_resnet(size)
model = ResNet()
img = jnp.zeros((224, 224, 3))                      # [H, W, C]
logits = model.apply(variables, img[None, ...])[0]  # [1000]

Task 1¶

Here we load a pre-trained model and prepare it for our purposes. We want the following:

  1. The function should operate on a single image instead of a batch. Although counterintuitive, this will make it easier to reason about explanations later and is more in line with the philosophy of jax.
  2. The function should take care of normalizing the image with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225] as done for the PyTorch models that this model was converted from. Refer to torchvision.transforms.Normalize for an example.
  3. Select out of the 1000 ImageNet classes the 10 ImageNette classes that we are interested in.
  4. The function should return the largest element of the 10-dimensional logits vector, since later on we'll often compute gradients of it. The full logits vector should also be returned for prediction and visualization purposes.

Complete the function logits_fn returned by load_resnet so that it fulfills the requirements above. Upon executing the cell you should see 7/8 correct predictions with almost-certain confidence.
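A minimal sketch of the normalization in step 2 (NumPy for brevity; jnp works identically). The helper name normalize_for_resnet matches the one used later when splitting the model for GradCAM, and the channel statistics come from the task description:

```python
import numpy as np

# ImageNet channel statistics from the task description
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_for_resnet(img):
    # img: [H, W, 3], float32 in [0, 1]; broadcasting handles the channel axis
    return (img - IMAGENET_MEAN) / IMAGENET_STD

# a mid-gray image maps to small per-channel values around zero
out = normalize_for_resnet(np.full((4, 4, 3), 0.5, dtype=np.float32))
```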

Explanation methods¶

Occlusion¶

The simplest explanation method consists of removing patches of the input image and measuring the effect on prediction confidence. Specifically, we want to measure the drop (or increase) in confidence for the predicted class between the original non-occluded image and an occluded version.

We will use a single square patch of fixed size that is scanned over the entire image without overlap, although it would be possible to come up with more advanced occlusion patterns.

Task 2¶

Complete the function prepare_occlusions that takes in a single image of shape [H, W, 3] and outputs a batch of images of shape [S, S, H, W, 3] where the image at [i, j] contains a black patch of size [H/S, W/S] whose top-left corner is placed at [i*H/S, j*W/S].

Explained with a drawing:

imgs[i, j] =
                j*W/S
      ┌───────────┬────┬────┐
      │           |    |    │
      │           |    |    │
i*H/S ├ ─ ─ ─ ─ ─ ┼────┤    │
      │           │####│    │
      │           │####│    │
      ├ ─ ─ ─ ─ ─ ┴────┘    │
      │                     │
      │                     │
      │                     │
      │                     │
      │                     │
      └─────────────────────┘

Remember that in jax arrays can not be modified in-place. Use at[].set() instead:

x[idx] = y  # Bad
x = x.at[idx].set(y)  # Good

Once the missing lines in prepare_occlusions are filled in, visualize the resulting batch of partially-occluded images to check your implementation.

In [5]:
def prepare_occlusions(img, steps: int):
    H, W, _ = img.shape
    imgs = jnp.tile(img, (steps, steps, 1, 1, 1))
    for i in range(0, steps):
        for j in range(0, steps):
            imgs = imgs.at[i, j, int(i*H/steps):int((i+1)*H/steps), int(j*W/steps):int((j+1)*W/steps), :].set(0)

    print(f"shape of imgs:{imgs.shape}")
    # imgs: [steps, steps, H, W, 3]
    return imgs


prepare_occlusions = jax.jit(prepare_occlusions, static_argnames="steps")

show_images(
    prepare_occlusions(viz_batch[0][0], steps=3).reshape(-1, 224, 224, 3),
    ncols=3,
    width_one_img_inch=1.5,
)
shape of imgs:(3, 3, 224, 224, 3)
No description has been provided for this image

Task 3¶

Using prepare_occlusions implemented above, complete the missing lines in occlusion_fn following this pseudo-code:

probs = f(img)
idx = argmax(probs)
imgs = prepare_occlusions(img)
relevance[i, j] = f(img)[idx] - f(imgs[i, j])[idx]
relevance = resize(relevance, img.shape)

With a working implementation, the code below will show positive and negative attributions for eight images. Positive attribution is shown as a red overlay, while negative attribution is shown in blue (almost invisible except for the last image).

Note: jit and vmap take care of speeding up and vectorizing occlusion_fn so that it works on a batch of images. You will see them used as wrappers or decorators throughout the notebook.

Tips:

  • you want to compute how much the probability of the original prediction drops: apply softmax to the output of logits_fn to get probabilities, then select the right class with idx
  • apply vmap twice to logits_fn to vectorize it over the two extra axes added by prepare_occlusions; you don't need two nested for loops
  • use jax.image.resize with method="bilinear" to resize the heatmap to the original size
  • use normalize_max to rescale the attributions to a range that works well with the visualization code
In [7]:
    
def occlusion_fn(logits_fn, variables, img, steps: int):
    H, W, _ = img.shape
    _, logits_orig = logits_fn(variables, img)
    probs = nn.softmax(logits_orig)
    idx = logits_orig.argmax()
    imgs = prepare_occlusions(img, steps)  # [steps, steps, H, W, 3]
    # vmap twice to cover the two leading axes added by prepare_occlusions
    logits_occ_fn = jax.vmap(jax.vmap(logits_fn, (None, 0)), (None, 0))
    _, logits_occ = logits_occ_fn(variables, imgs)
    probs_occ = nn.softmax(logits_occ, axis=-1)
    # drop in confidence for the originally predicted class
    relevance = probs[idx] - probs_occ[..., idx]
    relevance = jax.image.resize(relevance, (H, W), method="bilinear")
    attrib = normalize_max(relevance)
    # logits_orig: [num_classes]
    # attrib:      [H, W]
    return logits_orig, attrib


occlusion_fn = jax.jit(occlusion_fn, static_argnames=["logits_fn", "steps"])
occlusion_fn = jax.vmap(occlusion_fn, in_axes=(None, None, 0, None))

images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = occlusion_fn(logits_fn, variables, images, 6)

images = blend(images, RED, jnp.clip(relevance, a_min=0)[..., None])
images = blend(images, BLUE, -jnp.clip(relevance, a_max=0)[..., None])
show_images(images, labels, logits)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

Grad norm (sensitivity)¶

An important tool for decision explanation is the gradient of the prediction function with respect to the input variable evaluated at the input image. Intuitively, the gradient expresses how much a change in the input would affect the prediction (actually the pre-softmax confidence). By evaluating the gradient at the input image, we can estimate the relevance $R_i$ of each pixel $i$.

$$ \begin{align} X &\in \mathbb{R}^{D} \\ p &= f(X) \\ R_i &= \left(\nabla_X f(X)\right)_i \end{align} $$

Since our models operate on images, the gradient w.r.t. an image will have shape [H, W, 3]. For ease of visualization, we will compute the norm of the gradient at each pixel location and visualize it as a heatmap of shape [H, W].
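As a quick shape check (NumPy stand-in; jnp.linalg.norm behaves the same), the per-pixel norm collapses the channel axis:

```python
import numpy as np

# stand-in for the gradient of the top logit w.r.t. a tiny 4x4 "image"
grads = np.random.default_rng(0).normal(size=(4, 4, 3))
heat = np.linalg.norm(grads, axis=-1)  # per-pixel norm over channels -> [4, 4]
heat = heat / heat.max()               # rescale to [0, 1], similar in spirit to normalize_max
```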

Task 4¶

Complete the missing lines of grad_norm_fn so that given a single input image it returns the associated logits and the pixel-wise norm of the gradient of the most confident prediction.

Also, since we want to overlay the explanation on the image, make sure to scale the results in the range [0, 1].

Tips:

  • The function logits_fn prepared in task 1 returns the maximum logit as its first return value.
  • In jax one can use jax.value_and_grad to decorate a function so that both the value and its gradient are returned.
  • The function jax.value_and_grad can also take an extra parameter has_aux to indicate that the original function returns more than one value and that those extra values should be returned by the decorated function too. Example:
    def foo(a, x):
        y = jnp.exp(x**2) - jnp.sin(a @ x)
        return y.sum(), y
    
    foo_vg = jax.value_and_grad(foo, argnums=1, has_aux=True)
    (y_sum, y), grad_x = foo_vg(a, x)
    
  • use normalize_max to rescale the attributions to a range that works well with the visualization code
In [8]:
def grad_norm_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    # gradient of the max logit w.r.t. the input image, [H, W, 3]
    logits_vg_fn = jax.value_and_grad(logits_fn, argnums=1, has_aux=True)
    (_, logits), grads = logits_vg_fn(variables, img)
    # per-pixel gradient norm, rescaled to [0, 1]
    heat = jnp.linalg.norm(grads, axis=-1)
    grad = normalize_max(heat)
    # logits: [num_classes]
    # grad:   [H, W]
    return logits, grad


grad_norm_fn = jax.jit(grad_norm_fn, static_argnames=["logits_fn"])
grad_norm_fn = jax.vmap(grad_norm_fn, in_axes=(None, None, 0))

images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = grad_norm_fn(logits_fn, variables, images)

show_images(
    # images * relevance[..., None],
    blend(images, RED, relevance[..., None]),
    labels,
    logits,
)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

Grad x input¶

To increase the sharpness of the explanations it's possible to multiply the value of the gradient with the corresponding input. Intuitively, the gradient expresses the importance of a certain feature and is now rescaled by how much that feature is present.

$$R_i = X_i \cdot \left(\nabla_X f(X)\right)_i$$

Task 5¶

Modify grad_norm_fn so that the gradient is multiplied with the image before computing the norm.

In [9]:
def grad_x_input_fn(logits_fn, variables, img):
    H, W, _ = img.shape
    logits_vg_fn = jax.value_and_grad(logits_fn, argnums=1, has_aux=True)
    (_, logits), grads = logits_vg_fn(variables, img)
    # rescale the gradient by the input before taking the per-pixel norm
    heat = jnp.linalg.norm(img * grads, axis=-1)
    grad = normalize_max(heat)
    # logits: [num_classes]
    # grad:   [H, W]
    return logits, grad


grad_x_input_fn = jax.vmap(grad_x_input_fn, in_axes=(None, None, 0))
grad_x_input_fn = jax.jit(grad_x_input_fn, static_argnames=["logits_fn"])

images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = grad_x_input_fn(logits_fn, variables, images)

show_images(
    # images * relevance[..., None],
    blend(images, RED, relevance[..., None]),
    labels,
    logits,
)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

Integrated gradients¶

As observed, gradient-based explanations appear very noisy. This is because gradients can only describe what happens in the local neighborhood of the input image when a pixel is changed by a small quantity, therefore:

  • some pixels might be very important for the prediction (e.g. the color of a flower), but the local gradient might be saturated (e.g. different shades of yellow will all give the same confidence in "sunflower") and therefore those pixels will not be marked as relevant.
  • a small step in the direction of the gradient might increase the prediction confidence, but a slightly larger step might decrease it, and an even slightly larger step might increase it again.

The integrated gradients method proposes to address this issue by aggregating gradients along a linear path between the input image and a baseline (usually black). By considering a path, the noise associated with local gradients is reduced. Also, using a baseline image allows the explanation to be expressed in relative rather than absolute terms.

By expressing the path as $\gamma(\alpha) = B + \alpha(X-B)$, the method can be expressed as: $$ R_i = \int_0^1 \frac{\partial f(\gamma(\alpha))}{\partial\gamma_i(\alpha)} \frac{\partial\gamma_i(\alpha)}{\partial\alpha} \ d\alpha. $$

Which can be approximated as: $$ R_i \approx (X_i - B_i) \frac{1}{M} \sum_{m=1}^M \frac{\partial f\left(B + m/M(X-B)\right)}{\partial X_i}. $$

where $B$ denotes the black baseline and $M$ is the number of steps used to approximate the path integral.
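A useful sanity check for this approximation is the completeness property: the attributions should sum approximately to $f(X) - f(B)$. A small NumPy sketch with a made-up scalar function standing in for the network's logit (chosen so its gradient has a closed form):

```python
import numpy as np

# f is a made-up stand-in for the scalar logit
def f(x):
    return np.sum(np.tanh(x) ** 2)

def grad_f(x):
    t = np.tanh(x)
    return 2 * t * (1 - t ** 2)  # d/dx tanh(x)^2

def integrated_gradients(x, b, steps=200):
    alphas = np.linspace(1.0 / steps, 1.0, steps)  # m/M for m = 1..M
    # gradients evaluated along the straight path from b to x
    grads = np.stack([grad_f(b + a * (x - b)) for a in alphas])
    return (x - b) * grads.mean(axis=0)             # finite-sum approximation

x = np.array([0.5, -1.2, 2.0])
b = np.zeros_like(x)
attrib = integrated_gradients(x, b)
# completeness: attrib.sum() is close to f(x) - f(b)
```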

The function below computes all intermediate images between an input img and a black baseline.

In [10]:
@partial(jax.jit, static_argnames=["steps"])
def prepare_integrated_gradients(img, steps: int):
    assert img.ndim == 3
    return img[None, :, :, :] * jnp.linspace(1, 0, num=steps)[:, None, None, None]


image = viz_batch[0][0]
images = prepare_integrated_gradients(image, steps=8).reshape(-1, 224, 224, 3)
show_images(images, width_one_img_inch=2)
No description has been provided for this image

Task 6¶

Complete the function integrated_grad_fn so that:

  • a single image is taken as input
  • a prediction is made on the input image to determine the predicted class
  • a batch of progressively darker images is prepared with prepare_integrated_gradients
  • for each image in the batch, the gradient of the logit at idx is computed
  • the path integral is approximated using a finite sum

According to the official implementation, only positive attributions are considered and attributions are averaged per pixel.

Tips:

  • Store the index of the most-confident prediction for the input image as idx because we need to refer to it when computing gradients
  • At each intermediate step you don't want the gradient of max_logit, i.e. the first output of logits_fn, which is what you would get from grad(logits_fn). Instead you want the gradient of the idx-th element of logits, i.e. the second output. Define a local function or a lambda and call grad on that.
  • You don't need for loops, use vmap
In [12]:
def integrated_grad_fn(logits_fn, variables, img, steps: int):
    # predicted class on the original image
    _, logits_orig = logits_fn(variables, img)
    idx = logits_orig.argmax()
    baseline = jnp.zeros_like(img)
    # [steps, H, W, 3]: linear path from the image down to the black baseline
    images = prepare_integrated_gradients(img, steps)

    # gradient of the idx-th logit (not of max_logit) w.r.t. the image
    def logit_idx_fn(variables, img_):
        _, logits = logits_fn(variables, img_)
        return logits[idx]

    grad_fn = jax.grad(logit_idx_fn, argnums=1)
    grads = jax.vmap(grad_fn, in_axes=(None, 0))(variables, images)
    # finite-sum approximation of the path integral
    ig = (img - baseline) * grads.mean(axis=0)
    heat = jnp.linalg.norm(ig, axis=-1)
    grads = normalize_max(heat)
    # logits_orig: [num_classes]
    # grads:       [H, W]
    return logits_orig, grads


integrated_grad_fn = jax.jit(integrated_grad_fn, static_argnames=["logits_fn", "steps"])
integrated_grad_fn = jax.vmap(integrated_grad_fn, in_axes=(None, None, 0, None))

images, labels = viz_batch
logits_fn, variables = load_resnet(size=18)
logits, relevance = integrated_grad_fn(logits_fn, variables, images, 25)

show_images(
    # images * relevance[..., None],
    blend(images, RED, relevance[..., None]),
    labels,
    logits,
)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

GradCAM¶

GradCAM decomposes the ResNet model in two blocks: a CNN backbone and a linear classifier, separated by global average pooling. $$ \begin{align} X &\in \mathbb{R}^{H\times W\times 3} \\ A &= \text{Backbone}(X) \in \mathbb{R}^{H'\times W'\times K}\\ Y &= \text{Linear}(\text{GAP}(A)) \in \mathbb{R}^C \end{align} $$

The main idea is to consider the gradient of the activations before global average pooling and use them to rescale the intensity of the associated feature maps. Specifically, if the predicted class is $c$, the scaling factor for the $k$-th feature map is: $$ \alpha_c^k = \frac{1}{H'W'} \sum_i^{H'}\sum_j^{W'} \frac{\partial Y^c}{\partial A^k_{i,j}} $$

The feature maps are then combined into a sized-down attribution as: $$ R_c = \text{ReLU}\left( \sum_k^K \alpha_c^k A^k\right) \in \mathbb{R}^{H'\times W'} $$

Finally, the relevance heatmap is resized to match the input image. Compared to the gradient-based methods above, GradCAM produces much smoother heatmaps thanks to this upsampling operation.

Task 7¶

Implement the missing parts of grad_cam_fn:

  • First, process the image through the backbone
  • Then, process the features through global average pooling, then through the classifier. Remember that you'll need the gradients w.r.t. these features.
  • Once you have the gradients, combine them with the features as indicated above and resize the relevance to the same size of the input image.
  • As usual, return both the logits vector and the relevance matrix.

Tips:

  • You will need to apply the backbone and the classifier separately. The function load_resnet_for_grad_cam takes care of splitting the model and its variables for you.
    They are returned as two dictionaries, both containing the keys backbone and gap_cls.
  • When performing complex sums and products jnp.einsum can drastically simplify the amount of error-prone reshaping code required.
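For example, the weighted sum $\sum_k \alpha_c^k A^k$ followed by ReLU is a one-liner (NumPy sketch; shapes mirror the GradCAM setup, values made up):

```python
import numpy as np

rng = np.random.default_rng(0)
feats = rng.normal(size=(7, 7, 512))  # A: [h, w, K] feature maps from the backbone
alpha = rng.normal(size=(512,))       # alpha_c^k: one weight per feature map
# sum over the channel axis without any manual reshaping, then ReLU
cam = np.maximum(np.einsum("hwc,c->hw", feats, alpha), 0)
```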
In [26]:
def grad_cam_fn(fns, variables, img):
    H, W, _ = img.shape
    backbone_fn = fns["backbone"]
    gap_classifier_fn = fns["gap_cls"]
    backbone_vars = variables["backbone"]
    gap_classifier_vars = variables["gap_cls"]

    # run the image through the backbone, then through GAP + classifier
    features = backbone_fn(backbone_vars, img)  # [h, w, K]
    _, logits = gap_classifier_fn(gap_classifier_vars, features)
    # fix the target class c
    c = jnp.argmax(logits)

    # scalar logit Y^c as a function of the features
    def class_logit_fn(vars_, feats_):
        _, logits_ = gap_classifier_fn(vars_, feats_)
        return logits_[c]

    # gradients of Y^c w.r.t. the feature maps, [h, w, K]
    grads = jax.grad(class_logit_fn, argnums=1)(gap_classifier_vars, features)
    # channel weights alpha_c^k: global average of the gradients
    alpha = grads.mean(axis=(0, 1))
    # weighted combination of the feature maps, followed by ReLU
    relevance = jnp.maximum(jnp.einsum("hwc,c->hw", features, alpha), 0)
    # resize to the input image size and rescale to [0, 1]
    relevance_resized = jax.image.resize(relevance, (H, W), method="bilinear")
    relevance_resized = normalize_max(relevance_resized)
    # logits:            [num_classes]
    # relevance_resized: [H, W]
    return logits, relevance_resized


def load_resnet_for_grad_cam(size):
    @jax.jit
    def backbone_fn(variables, img):
        # img:   [H, W, C], float32 in range [0, 1]
        # feats: [h, w, c], float32
        img = normalize_for_resnet(img)
        feats = backbone.apply(variables, img[None, ...], mutable=False)[0]
        return feats

    @jax.jit
    def gap_classifier_fn(variables, feats):
        # feats:  [h, w, c], float32
        # logit:  float32
        # logits: [10], float32
        logits = gap_classifier.apply(variables, feats[None, ...], mutable=False)[0]
        logits = imagenet_to_imagenette_logits(logits)
        return logits.max(), logits

    ResNet, variables = jax_resnet.pretrained_resnet(size)
    model = ResNet()

    backbone = nn.Sequential(model.layers[:-2])
    backbone_vars = jax_resnet.slice_variables(variables, start=0, end=-2)
    gap_classifier = nn.Sequential(model.layers[-2:])
    gap_classifier_vars = jax_resnet.slice_variables(variables, start=len(model.layers) - 2, end=None)
    return (
        flax.core.freeze({"backbone": backbone_fn, "gap_cls": gap_classifier_fn}),
        flax.core.freeze({"backbone": backbone_vars, "gap_cls": gap_classifier_vars}),
    )


grad_cam_fn = jax.jit(grad_cam_fn, static_argnames=["fns"])
grad_cam_fn = jax.vmap(grad_cam_fn, in_axes=(None, None, 0))

images, labels = viz_batch
fns, variables = load_resnet_for_grad_cam(size=18)
logits, relevance = grad_cam_fn(fns, variables, images)

show_images(
    images * relevance[..., None],
    labels,
    logits,
)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

Deletion score¶

So far, we have evaluated the explanations qualitatively by drawing them as heatmaps over a few sample images. A common metric used to evaluate the correctness of an explanation method quantitatively is the deletion score. It is computed by progressively removing pixels from an image in order of importance and measuring the corresponding drop in confidence. The behavior can be visualized on a plot that has the percentage of removed pixels on the horizontal axis and the prediction confidence on the vertical axis.

Ideally, if the most-relevant pixels are actually important for the prediction, their removal should induce a sudden drop in confidence for that label. To summarize this idea with a number we can compute the area under the curve: a low area indicates a quick decline in confidence, hence a good explanation method. This value is denoted as deletion score.
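As a worked example of the score (all numbers made up), the area under a toy deletion curve can be computed with the trapezoidal rule, which is what sklearn.metrics.auc does:

```python
import numpy as np

# fraction of removed pixels and a made-up confidence curve with a steep drop
fractions = np.linspace(0, 1, 6)
confidence = np.array([0.95, 0.40, 0.15, 0.08, 0.05, 0.04])
# trapezoidal area under the curve: lower means a quicker confidence collapse
auc = (0.5 * (confidence[:-1] + confidence[1:]) * np.diff(fractions)).sum()
# auc is approximately 0.235
```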

Single-image deletion score¶

Task 8¶

Implement a function prepare_deletion that, given an image and the associated relevance, prepares a batch of steps images. If we indicate the resulting batch as imgs, note that:

  • imgs[0] corresponds to the input image
  • imgs[s] is a copy of the input image with a fraction s/(steps-1) of its pixels set to black
  • imgs[-1] is an all-black image
  • pixels are set to zero in order of relevance with the most-relevant ones first
  • if a pixel (i, j) is set to black at step s it will remain black in all subsequent steps, i.e. the images become progressively more black

The code below samples a random relevance mask and shows the images resulting from prepare_deletion so that you can verify the implementation. Check that the first regions to become black are the ones with the highest relevance.

Tips:

  • It's easier to reason about the flattened versions of the image and the relevance matrices, use jnp.ndarray.flatten and jnp.unravel_index to move back and forth between one and two dimensions
  • use jnp.argsort and jnp.array_split to sort and split the relevance, but be careful about sorting in ascending/descending order
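The two helpers combine like this (tiny NumPy example with a made-up 2x2 relevance map):

```python
import numpy as np

rel = np.array([[0.1, 0.9],
                [0.5, 0.3]])
order = np.argsort(rel.flatten())[::-1]          # flat indices, most relevant first
rows, cols = np.unravel_index(order, rel.shape)  # back to 2-D coordinates
# deletion order of pixels: (0, 1), (1, 0), (1, 1), (0, 0)
```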
In [15]:
def prepare_deletion(img, relevance, steps: int):
    assert relevance.shape == img.shape[:2]
    H, W, _ = img.shape
    imgs = jnp.tile(img, (steps, 1, 1, 1))
    # flat pixel indices sorted by decreasing relevance
    indices = jnp.argsort(relevance.flatten())[::-1]
    # one chunk of indices to blacken at each step
    idx_sections = jnp.array_split(indices, steps - 1)
    prev_img = img
    for s in range(1, steps):
        i, j = jnp.unravel_index(idx_sections[s - 1], (H, W))
        curr_img = prev_img.at[i, j, :].set(0)
        imgs = imgs.at[s].set(curr_img)
        prev_img = curr_img
    # imgs: [steps, H, W, 3]
    return imgs


prepare_deletion = jax.jit(prepare_deletion, static_argnames="steps")

relevance = jax.random.uniform(jax.random.PRNGKey(42), (7, 7))
relevance = jax.image.resize(relevance, (224, 224), method="bilinear")
relevance = normalize_zero_one(relevance)
#print(f"relevance shape:{relevance.shape}")
#print(f"relevance:{relevance}")

image = plt.get_cmap('viridis')(relevance)[..., :3]

steps=8
images = prepare_deletion(image, relevance, steps)
assert images.shape == (steps, *image.shape)
show_images(images, width_one_img_inch=2)
No description has been provided for this image

Task 9¶

Using the prepare_deletion function implemented above, complete the missing lines of deletion_score_fn:

  • The function takes as input an image, its relevance, and a number of steps
  • The function returns a vector of length steps containing the probabilities associated to the top-scoring class predicted by the model as more and more relevant pixels are removed
  • The function returns the original prediction pred_orig too

The cell below contains a few lines of code for plotting the resulting curve and the associated score. You should see the confidence curve slowly decreasing to zero and eventually rising slightly.

Tips:

  • The expected value for the area under the curve is 0.270
  • You don't need for loops, use vmap
In [16]:
def deletion_score_fn(logits_fn, variables, img, relevance, steps):
    H, W, _ = img.shape
    # original prediction
    _, logits_orig = logits_fn(variables, img)
    pred_orig = logits_orig.argmax()
    # images with progressively more relevant pixels deleted
    imgs = prepare_deletion(img, relevance, steps)
    # probability of the original prediction at each step
    _, logits_all = jax.vmap(logits_fn, (None, 0))(variables, imgs)
    probs_all = nn.softmax(logits_all, axis=-1)
    probs = probs_all[:, pred_orig]
    # probs:     [steps]
    # pred_orig: int
    return probs, pred_orig


deletion_score_fn = jax.jit(deletion_score_fn, static_argnames=["logits_fn", "steps"])

image = viz_batch[0][2]
relevance = jax.random.uniform(jax.random.PRNGKey(42), (7, 7))
relevance = jax.image.resize(relevance, image.shape[:2], method="bilinear")
relevance = normalize_zero_one(relevance)

steps = 8
logits_fn, variables = load_resnet(size=18)
probs, pred = deletion_score_fn(logits_fn, variables, image, relevance, steps)
auc = sklearn.metrics.auc(np.linspace(0,1,steps), probs)
assert len(probs) == steps

fig, ax = plt.subplots(1, 1, figsize=(9, 4))
ax.fill_between(np.linspace(0, 1, steps), probs)
ax.set_ylabel(f"Confidence for '{CLASS_NAMES[pred]}'")
ax.grid(axis="y")
ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
ax.set_xlabel("Pixels removed")
ax.set_title(f"Deletion score: {auc:.3f}");
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
No description has been provided for this image

Visualize deletion on some images¶

The following code visualizes the explanations of all methods for the eight images used so far. On the right, the deletion curve and its area are also plotted.

Task 10¶

No implementation required, just take a moment to compare the explanations:

  • How do they look side-by-side? Does the deletion curve match what you see in the image?
  • Which one do you trust the most? Motivate your feeling!
  • Does your judgement correlate well with the deletion score?
  • Is there a method that is consistently better?

Add your comments below:

  • Yes, the deletion curves generally agree with the visual explanations. Smooth, focused heatmaps correspond to steep drops (better attribution). For example, in images where the red region is concise and on the object (e.g. Grad-CAM or IG on the gas pump), the confidence drops quickly, giving low deletion scores. For the church images, where the heatmap is noisy or off-target (e.g. raw Grad), the curve fluctuates or stays high for longer.
  • Looking at the images, I trust Grad-CAM the most, as it clearly focuses on the key semantic regions of each object (e.g. the center of the gas pump, the body of the truck, or the middle of the horn), making it easy to interpret visually. Integrated Gradients gives similar relevance but with smoother transitions and less noise than plain Grad or Grad×Input. Occlusion is interpretable but coarse: it blocks large regions and therefore loses spatial precision.
  • Yes. Images with object-aligned heatmaps (IG) also have low deletion scores, meaning the method correctly identified critical pixels. Methods with noisy or diffuse heatmaps (Grad, Grad×Input) show higher and fluctuating curves, confirming less stable explanations. Grad-CAM focuses on interpretable, high-level regions rather than pixel-exact importance, so even though its heatmaps look convincing, it has higher deletion scores because it lacks sharp, pixel-level sensitivity. Similarly, the deletion behavior shows that occlusion finds the right regions but not the exact pixels: confidence drops slowly because each removed patch mixes relevant and irrelevant pixels, confirming that occlusion is reliable but coarse and less efficient than the gradient-based methods.
  • Yes, Integrated Gradients consistently shows the lowest deletion score, confirming its quantitative reliability. Grad-CAM is consistently strong visually and gives interpretable, smooth heatmaps.
In [28]:
logits_fn, variables = load_resnet(size=18)
logits_fns_gc, variables_gc = load_resnet_for_grad_cam(size=18)

# Use lambdas instead of partial because vmap doesn't play well with kwargs
all_methods = {
    "occlusion": lambda images: occlusion_fn(logits_fn, variables, images, 6),
    "grad": lambda images: grad_norm_fn(logits_fn, variables, images),
    "grad_x_input": lambda images: grad_x_input_fn(logits_fn, variables, images),
    "grad_cam": lambda images: grad_cam_fn(logits_fns_gc, variables_gc, images),
    "integrated_gradients": lambda images: integrated_grad_fn(logits_fn, variables, images, 20),
}

# These logits are only needed for visualization
images, labels = viz_batch
_, logits = jax.vmap(logits_fn, (None, 0))(variables, images)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
In [34]:
deletion_score_fn_vmap = jax.vmap(deletion_score_fn, in_axes=(None, None, 0, 0, None))

fig, axs = plt.subplots(
    len(images),
    len(all_methods) + 1,
    figsize=((len(all_methods) + 3) * 3, len(images) * 3),
    gridspec_kw={"width_ratios": len(all_methods) * [1] + [3], "wspace": 0.001},
)

# Write true/predicted class on the left side of each row
for ax, lb, lg in zip(axs[:, 0], labels, logits):
    ax.set_ylabel(f"True {CLASS_NAMES[lb]}\nPred {CLASS_NAMES[lg.argmax()]}")

# Column headers: method names + deletion curve
for ax, method in zip(axs[0, :-1], all_methods.keys()):
    ax.set_title(method)
axs[0, -1].set_title("Deletion curve")

# Each explanation method gets its own column on the left and its own curve on the right
for method_col, (method, method_fn) in zip(axs[:, :-1].T, all_methods.items()):
    _, relevance = method_fn(images)
    for ax, img, rel in zip(method_col, images, relevance):
        ax.imshow(blend(img, RED, normalize_zero_one(rel)[..., None]))

    probs, _ = deletion_score_fn_vmap(logits_fn, variables, images, relevance, 25)
    for ax, p in zip(axs[:, -1], probs):
        auc = sklearn.metrics.auc(np.linspace(0, 1, len(p)), p)
        ax.plot(np.linspace(0, 1, len(p)), p, label=f"{auc:.3f} {method}")

# Remove inner ticks for the image grid
for ax in axs[:-1, :-1].flat:
    ax.set_xticks([])
for ax in axs[:, 1:-1].flat:
    ax.set_yticks([])

# Annotate right column on the rightmost edge
for ax in axs[:, -1]:
    axt = ax.twinx()
    axt.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
    axt.set_ylabel("Confidence")
    axt.grid()
    ax.set_yticks([])
    ax.legend(loc="upper right", framealpha=1.0)

# Remove pixel percent ticks from right column, except at the bottom
for ax in axs[:-1, -1]:
    ax.set_xticklabels([])
axs[-1, -1].xaxis.set_major_formatter(mpl.ticker.PercentFormatter(xmax=1))
axs[-1, -1].set_xlabel("Pixels removed");
No description has been provided for this image

Average deletion score on entire dataset¶

To better compare the explanation methods, we can compute the average deletion score across the entire dataset. This value should give us an indication of which method is best at identifying relevant pixels for a prediction.

Task 11¶

No implementation required, just consider the results of this evaluation:

  • Which method seems to be best?
  • Can you trust results with such a high standard deviation?
  • What can be the cause of it? Think both of how the metric is computed and of how the content of an image might affect the score.

Add your comments below:

  • Integrated Gradients seems to be the best, with the lowest mean deletion score, indicating that it identifies the most relevant pixels overall.
  • The standard deviations are quite large, almost comparable with the mean values, meaning the deletion score varies a lot from one image to another and that we should be cautious when interpreting the ranking.
  • One cause of the high standard deviation could be that the deletion score depends on how quickly the model's confidence drops as pixels are removed, which can vary dramatically between images depending on the object's size, location, and background complexity. Images where the object occupies most of the frame (e.g., a church filling the image) will show slow confidence decay, whereas small, localized objects (like a gas pump) will drop sharply.

Warning: the following code might take a long time to run and/or run out of memory. For reference, on a single GPU with 10 GB of memory and batch size 32, the total time for all the loops is approximately 10 minutes. You may need to reduce the batch size or limit the number of images.

In [35]:
deletion_steps = 10

avg_auc = []
deletion_steps_arr = np.linspace(0, 1, deletion_steps)
logits_fn, variables = load_resnet(size=18)
fns_gc, variables_gc = load_resnet_for_grad_cam(size=18)

ds = ds_builder.as_dataset(split="train", batch_size=None, as_supervised=True)
ds = ds.map(resize, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True)
ds = ds.batch(32, drop_remainder=True, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True)
ds = ds.take(50)  # limit number of batches
ds = ds.prefetch(4)
ds = tfds.as_numpy(ds)

for method, method_fn in all_methods.items():
    aucs = []
    for images, labels in tqdm.tqdm(ds, ncols=0, desc=method):
        _, relevance = method_fn(images)
        probs, _ = deletion_score_fn_vmap(logits_fn, variables, images, relevance, deletion_steps)
        aucs.extend(sklearn.metrics.auc(deletion_steps_arr, p) for p in probs)
    avg_auc.append({"method": method, "mean": np.mean(aucs), "std": np.std(aucs)})
    
avg_auc = pd.DataFrame(avg_auc)
display(avg_auc.set_index("method"))

fig, ax = plt.subplots(1, 1, figsize=(8, 4), facecolor="white")
avg_auc.plot(
    "method",
    "mean",
    yerr="std",
    kind="bar",
    rot=0,
    figsize=(10, 5),
    legend=None,
    ylim=(0, 1),
    xlabel="",
    title="Average deletion score (lower is better)",
    ax=ax
)
display(fig)
plt.close(fig)
Using cache found in /Users/silpasoninallacheruvu/.cache/torch/hub/pytorch_vision_v0.10.0
occlusion: 100% 50/50 [15:33<00:00, 18.67s/it]
grad: 100% 50/50 [04:28<00:00,  5.37s/it]
grad_x_input: 100% 50/50 [04:33<00:00,  5.47s/it]
grad_cam: 100% 50/50 [03:59<00:00,  4.79s/it]
integrated_gradients: 100% 50/50 [41:14<00:00, 49.48s/it]
method                     mean       std
occlusion              0.328927  0.227129
grad                   0.257688  0.231876
grad_x_input           0.219738  0.185848
grad_cam               0.307005  0.206437
integrated_gradients   0.210839  0.185165
(Figure: bar chart of the average deletion score per method, with standard-deviation error bars; lower is better.)

Conclusion¶

To help us improve this practical:

  • How long did it take to complete this notebook? 27 hrs
  • What was the most difficult part? The Grad-CAM implementation took quite some time, as did interpreting the results of each method.